Skip to content

JaxCNN#35

Open
andyyPark wants to merge 13 commits intoLSSTDESC:masterfrom
andyyPark:master
Open

JaxCNN#35
andyyPark wants to merge 13 commits intoLSSTDESC:masterfrom
andyyPark:master

Conversation

@andyyPark
Copy link
Copy Markdown
Member

@andyyPark andyyPark commented Sep 1, 2020

This method was inspired by #9 where I exploited the auto differentiable metrics from jax-cosmo library. Flax was used to implement a simple Convolutional Neural Network that assigns bins to galaxies.

My network was

  • optimized for the total 3x2 FOM
  • trained on riz bands

I have created a jupyter notebook, JaxCNN.ipynb, that walks through my code in the notebooks folder, but I still have yet to finish running the notebook on NERSC.

Below is an example of the binning generated for 4 bins:
4_riz

Scores.ipynb in the notebooks folder shows the plots of the metrics for a different number of bins.

FOM_3x2

FOM_3x2

FOM_DETF_3x2

FOM_DETF_3x2

SNR_3x2

SNR_3x2

@andyyPark andyyPark marked this pull request as draft September 1, 2020 07:01
@EiffL
Copy link
Copy Markdown
Member

EiffL commented Sep 1, 2020

@andyyPark Thanks for your entry! Ahaha, am I glad to see another JAX neural network approach! When you have them, could you add to your description a few metrics, I'm very curious to see how the CNN compares to the Dense network ;-)

@EiffL EiffL added the entry Challenge entry label Sep 1, 2020
@andyyPark andyyPark marked this pull request as ready for review September 8, 2020 03:02
@andyyPark
Copy link
Copy Markdown
Member Author

andyyPark commented Sep 15, 2020

Since my original submission, I have changed my original jaxCNN to maximize the FOM_DETF score, and have implemented ResNet50 (jaxResNet.py) using jax and jax-cosmo library. Although I don't have any scores to show, below is an example of the binning generated for 5 and 6 bins using the Buzzard dataset:

5 Bins (jaxCNN)

image

6 Bins (jaxCNN)

image

5 Bins (jaxResNet)

image

6 Bins (jaxResNet)

image

I haven't tried this yet but it seems like my original CNN network (jaxCNN.py) performs better than ResNet50 with epochs ~O(100).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

entry Challenge entry

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants